import torch
import argparse
import numpy as np
import math
import random
import os
import yaml
import re
import sys
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn as nn
from torch.utils import data
from torch import autograd
from torch.autograd import Variable
from codebase import utils as ut
from codebase.models import nns 
from torchvision import datasets, transforms
from PIL import Image
from vae_celeba import CauFVAE
from torchvision.utils import save_image
from datetime import datetime
from Conditioners import *
# True when at least one CUDA device is visible to PyTorch.
cuda = torch.cuda.is_available()

# Hyper-parameters of the experiment. Every option has a default so the script
# can run without any command-line input.
parser = argparse.ArgumentParser(description='Disentangled representation based on CauF-VAE')
parser.add_argument('--dataset', type=str, default='celeba', choices=['celeba', 'pendulum'])
parser.add_argument('--data_dir', type=str, default='./causal_data', help='data directory')
parser.add_argument('--labels', type=str, default='smile', help='name of the underlying structure')
parser.add_argument('--epoch_max',   type=int, default=101,    help="Number of training epochs")
parser.add_argument('--iter_save',   type=int, default=20, help="Save model every n epochs")
# The last entry of emb_net is used as the conditioner output size, the
# preceding entries as its hidden layer sizes (see conditioner_args below).
parser.add_argument("-emb_net", default=[200,200, 2], nargs="+", type=int, help="NN layers of embedding")
parser.add_argument('--dim',       type=int, default=100,     help="total latent dimension")
parser.add_argument('--dim1',       type=int, default=0,     help="latent dimension1")
parser.add_argument('--dim2',       type=int, default=6,     help="DAG latent dimension")
parser.add_argument('--dim3',       type=int, default=94,     help="latent dimension2")
parser.add_argument('--n_layers_CauF', type=int, default=1,     help="Number of Causal flow")
parser.add_argument('--sup_coef', type=float, default=5, help='coefficient of the supervised regularizer')
parser.add_argument('--sup_prop', type=float, default=1., help='proportion of supervised labels')
parser.add_argument('--sup_type', type=str, default='ce', choices=['ce', 'l2'])
parser.add_argument('--lr_d', type=float, default=3e-4, help='lr of decoder')
parser.add_argument('--lr_e', type=float, default=3e-4, help='lr of encoder')
parser.add_argument('--lr_f', type=float, default=3e-4, help='lr of flow')
parser.add_argument('--lr_a', type=float, default=1e-3, help='lr of adjacency matrix')
parser.add_argument('--beta1', type=float, default=0.2)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--beta3', type=float, default=0.2)
parser.add_argument('--beta4', type=float, default=0.999)
# A learning rate is fractional: with type=int argparse rejected every
# explicitly supplied value such as "1e-3"; it must be parsed as a float.
parser.add_argument('--lr_p', type=float, default=3e-4, help='lr of prior')
parser.add_argument('--n_traverse', type=int, default=10)
parser.add_argument('--n_intervene', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=128)
# NOTE(review): an empty argument list is parsed, so flags given on the command
# line are ignored and only the defaults above apply. This looks deliberate
# (e.g. notebook usage); switch to parser.parse_args() to honour sys.argv.
args = parser.parse_args([])
device = torch.device("cuda:0" if cuda else "cpu")

# Column indices selected from each batch's label tensor during training
# (used as label[:, label_idx] in the training loop).
if 'pendulum' in args.dataset:
    label_idx = range(4)
elif args.labels == 'smile':
    label_idx = [31, 20, 19, 21, 23, 13]
elif args.labels == 'attractive':
    label_idx = [39, 20, 28, 18, 13, 3]
else:
    # Without this branch label_idx stayed undefined and the script failed
    # later with an unrelated-looking NameError.
    raise ValueError(
        "Unsupported --labels value {!r} for dataset {!r}; "
        "expected 'smile' or 'attractive'".format(args.labels, args.dataset)
    )

# Number of supervised labels; also the side length of the prior adjacency matrix.
num_label = len(label_idx)

# Prior adjacency matrix handed to the conditioner as "A_prior".
# Each (row, col) pair listed for the selected label structure is set to 1;
# every other entry stays 0. An unknown structure name leaves A all-zero.
_PRIOR_EDGES = {
    'smile': ((2, 0), (3, 0), (4, 0), (4, 1), (5, 0)),
    'attractive': ((2, 0), (3, 0), (4, 0), (5, 0), (2, 1), (3, 1)),
    'pend': ((2, 0), (2, 1), (3, 0), (3, 1)),
}

A = torch.zeros((num_label, num_label))
for _row, _col in _PRIOR_EDGES.get(args.labels, ()):
    A[_row, _col] = 1
A = A.to(device)

# Keyword arguments for DAGConditioner: input size is the DAG latent
# dimension, the embedding net layout comes from --emb_net (all but the last
# entry are hidden sizes, the last entry is the output size), and A is the
# prior adjacency matrix built above.
conditioner_args = dict(
    in_size=args.dim2,
    hidden=args.emb_net[:-1],
    out_size=args.emb_net[-1],
    A_prior=A,
)
# Stand-alone conditioner instance. It is not referenced again in this script
# (the training loop rebinds the name), but it is still constructed here.
conditioner = DAGConditioner(**conditioner_args)

# The model receives the conditioner class together with its keyword
# arguments, followed by the latent dimensions and number of flow layers.
vae = CauFVAE(
    DAGConditioner,
    conditioner_args,
    args.dim,
    args.n_layers_CauF,
    args.dim1,
    args.dim2,
    args.dim3,
    nn='network',
)
vae = vae.to(device)

# Collect the parameters of each sub-module. The encoder parameters are kept
# as the iterator returned by .parameters(); the others are materialised lists.
enc_param = vae.enc.parameters()
dec_param, flow_param, prior_param = (
    list(part.parameters()) for part in (vae.dec, vae.flow, vae.prior)
)

# One Adam optimizer per component so each can use its own learning rate.
# The first flow parameter gets its own optimizer at lr_a.
# NOTE(review): presumably flow_param[0] is the learnable adjacency matrix —
# confirm against the flow module's parameter order.
encoder_optimizer = optim.Adam(
    enc_param, lr=args.lr_e, betas=(args.beta1, args.beta2)
)
decoder_optimizer = optim.Adam(
    dec_param, lr=args.lr_d, betas=(args.beta1, args.beta2)
)
prior_optimizer = optim.Adam(
    prior_param, lr=args.lr_p, betas=(args.beta3, args.beta4)
)
A_optimizer = optim.Adam(flow_param[0:1], lr=args.lr_a)
flow_optimizer = optim.Adam(
    flow_param[1:], lr=args.lr_f, betas=(args.beta1, args.beta2)
)

# Output directory for the reference and reconstruction images written below.
# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs('./figs_celeba_smile/', exist_ok=True)

# Placeholders; nothing in this script assigns real loaders to these names.
test_loader = None
train_loader = None

# Shared preprocessing: centre-crop to 128, resize to 64x64, convert to a
# tensor and normalise every channel with mean 0.5 and std 0.5.
trans_f = transforms.Compose([
    transforms.CenterCrop(128),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Training split: shuffled mini-batches, incomplete final batch dropped.
train_set = datasets.CelebA(
    args.data_dir, split='train', download=True, transform=trans_f
)
train_dataset = torch.utils.data.DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=False,
    drop_last=True,
    num_workers=8,
)

# Test split: one image per batch, in dataset order.
test_set = datasets.CelebA(
    args.data_dir, split='test', download=True, transform=trans_f
)
test_dataset = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    shuffle=False,
    pin_memory=False,
    drop_last=True,
    num_workers=8,
)

# Walk the test loader until its 44th batch and keep that image as the fixed
# reference sample used for traversal / intervention plots every epoch.
for i, (x, label) in enumerate(test_dataset, start=1):
    test_data = x
    if i == 44:
        break
save_image(test_data, './figs_celeba_smile/test_data_true.png', normalize=True)

def save_model_by_name(model, global_step, save_dir='./checkpoints_smile'):
    """Write the model's state dict to ``<save_dir>/model-<step>.pt``.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        global_step: integer used for the zero-padded (5 digit) file name.
        save_dir: target directory; created (with parents) when missing.
            Defaults to the directory this script has always used.
    """
    # exist_ok=True makes directory creation safe against concurrent creators.
    os.makedirs(save_dir, exist_ok=True)
    file_path = os.path.join(save_dir, 'model-{:05d}.pt'.format(global_step))
    torch.save(model.state_dict(), file_path)
    print('Saved to {}'.format(file_path))

def make_folder(path):
    """Create directory ``path`` (including parents) if it does not exist yet.

    Calling it on an existing directory is a no-op.
    """
    # exist_ok=True replaces the racy "check, then create" sequence.
    os.makedirs(path, exist_ok=True)

# Main training loop. Each epoch: constrain the adjacency matrices, run one
# pass over the training loader, save a reconstruction of the last batch,
# render traversal / intervention grids for the fixed test sample, print the
# averaged losses and periodically checkpoint the model.
for epoch in range(args.epoch_max):
    vae.train()

    # Apply the conditioners' adjacency constraint without tracking gradients.
    with torch.no_grad():
        for conditioner in vae.getConditioners():
            conditioner.constrainA(zero_threshold=0.)

    vae.to(device)
    total_loss = 0
    total_rec = 0
    total_kl = 0
    total_sup_loss = 0
    # Number of batches per epoch; used to average the accumulated losses.
    m = len(train_dataset)

    for batch_idx, (x, label) in enumerate(train_dataset):
        x = x.to(device)
        label = label.to(device)
        # Keep only the supervised label columns and map any -1 entry to 0.
        # Done as one vectorised op instead of a Python loop over every
        # element, and it no longer assumes the batch has args.batch_size rows.
        label = label[:, label_idx].float()
        label = label.masked_fill(label == -1, 0.)

        A_optimizer.zero_grad()
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        flow_optimizer.zero_grad()
        prior_optimizer.zero_grad()
        rec, kl, suploss, L, reconstructed_image = vae(
            x, label, args.sup_type, args.sup_prop, args.sup_coef
        )
        L.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()
        prior_optimizer.step()

        # Project-defined hook that runs after backward() and before the
        # adjacency / flow updates; the call order is kept as is.
        # NOTE(review): presumably it masks some gradients — confirm in CauFVAE.
        vae.set_zero_grad()
        A_optimizer.step()
        flow_optimizer.step()

        total_loss += L.item()
        total_kl += kl.item()
        total_rec += rec.item()
        total_sup_loss += suploss.item()

    # Save the second sample of the last batch next to its reconstruction.
    reconstructed_image = reconstructed_image[1]
    save_image(x[1], './figs_celeba_smile/reconstructed_image_true_{}.png'.format(epoch), normalize=True)
    save_image(reconstructed_image, './figs_celeba_smile/reconstructed_image_{}.png'.format(epoch), normalize=True)

    vae.eval()
    save_dir = './results_celeba_smile/{}/{}/'.format(args.dataset, args.labels)
    make_folder(save_dir)
    # NOTE(review): test_data is never moved to `device` in this script;
    # presumably traverse()/intervene() handle placement — confirm.
    with torch.no_grad():
        # Traverse
        sample = vae.traverse(test_data, gap=2, n=args.n_traverse).to(device)
        save_image(sample, save_dir + 'trav_' + str(epoch) + '.png', normalize=True, nrow=args.n_traverse)
        del sample
        # intervene
        intervene_x = vae.intervene(test_data, gap=2, n=args.n_intervene).to(device)
        save_image(intervene_x, save_dir + 'intervene_' + str(epoch) + '.png', normalize=True, nrow=args.n_intervene)
        del intervene_x

    # Per-batch averages of the losses accumulated over this epoch.
    print('{} loss:{} kl:{} sup:{} rec:{} m:{}'.format(
        epoch, total_loss / m, total_kl / m, total_sup_loss / m, total_rec / m, m
    ))
    if epoch % args.iter_save == 0:
        save_model_by_name(vae, epoch)